# coding=utf-8
# Copyright 2024 The Google Research Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Runs ETC model for Natural Questions."""

import json
import os
import re

import tensorflow.compat.v1 as tf
from tensorflow.compat.v1 import estimator as tf_estimator

from etcmodel.models import modeling
from etcmodel.models.nq import eval_nq_lib
from etcmodel.models.nq import run_nq_lib

tf.compat.v1.disable_v2_behavior()

# All flags below are registered on the same module object; `flags` is used
# consistently (the original mixed `flags.` and `tf.flags.` for the same thing).
flags = tf.flags
FLAGS = flags.FLAGS

flags.DEFINE_string(
    "etc_config_file", None,
    "The config json file corresponding to the pre-trained ETC model. "
    "This specifies the model architecture.")

flags.DEFINE_string(
    "gold_cache_path", None,
    "Path to the gold data cache generated by the make_gold_cache.py script "
    "(e.g. '/path/to/cache').")

flags.DEFINE_string(
    "output_dir", None,
    "The output directory where the model checkpoints will be written.")

flags.DEFINE_string("train_precomputed_file", None,
                    "Precomputed tf records for training.")

flags.DEFINE_integer("train_num_precomputed", None,
                     "Number of precomputed tf records for training.")

flags.DEFINE_string(
    "predict_precomputed_file", None,
    "Precomputed tf records for predictions, used by both `do_predict` and "
    "`do_inference`.")

flags.DEFINE_string(
    "init_checkpoint", None,
    "Initial checkpoint (usually from a pre-trained ETC model). Also the "
    "checkpoint used for predictions when `do_inference=True`.")

flags.DEFINE_integer(
    "max_seq_length", 4096,
    "The maximum total input sequence length after WordPiece tokenization. "
    "Sequences longer than this will be truncated, and sequences shorter "
    "than this will be padded.")

flags.DEFINE_integer(
    "max_global_seq_length", 230,
    "The maximum total global sequence length. "
    "Sequences longer than this will be truncated, and sequences shorter "
    "than this will be padded.")

flags.DEFINE_bool("do_train", False, "Whether to run training.")

flags.DEFINE_bool("do_predict", False, "Whether to run eval on the dev set.")

flags.DEFINE_bool("do_inference", False, "Whether to run inference.")

flags.DEFINE_string(
    "inference_output_name", "inference",
    "Name to use as part of the file name for inference outputs when "
    "`do_inference=True`.")

flags.DEFINE_integer("train_batch_size", 64, "Total batch size for training.")

flags.DEFINE_integer("predict_batch_size", 8,
                     "Total batch size for predictions.")

flags.DEFINE_enum("optimizer", "adamw", ["adamw", "lamb"],
                  "The optimizer for training.")

flags.DEFINE_float("learning_rate", 5e-5,
                   "The initial learning rate for the optimizer.")

flags.DEFINE_enum(
    "learning_rate_schedule", "poly_decay", ["poly_decay", "inverse_sqrt"],
    "The learning rate schedule to use. The default of "
    "`poly_decay` uses tf.train.polynomial_decay, while "
    "`inverse_sqrt` uses inverse sqrt of time after the warmup.")

flags.DEFINE_float("poly_power", 1.0, "The power of poly decay.")

flags.DEFINE_float("num_train_epochs", 3.0,
                   "Total number of training epochs to perform.")

flags.DEFINE_float(
    "warmup_proportion", 0.1,
    "Proportion of training to perform linear learning rate warmup for. "
    "E.g., 0.1 = 10% of training.")

flags.DEFINE_integer("start_warmup_step", 0, "The starting step of warmup.")

flags.DEFINE_integer("save_checkpoints_steps", 5000,
                     "How often to save the model checkpoint.")

flags.DEFINE_integer("iterations_per_loop", 1000,
                     "How many steps to make in each estimator call.")

flags.DEFINE_integer(
    "grad_checkpointing_period", None,
    "If specified, this overrides the corresponding `EtcConfig` value loaded "
    "from `etc_config_file`.")

flags.DEFINE_integer("random_seed", 0, "Dummy flag used for random restarts.")

flags.DEFINE_bool("use_tpu", False, "Whether to use TPU or GPU/CPU.")

flags.DEFINE_string(
    "tpu_name", None,
    "The Cloud TPU to use for training. This should be either the name "
    "used when creating the Cloud TPU, or a grpc://ip.address.of.tpu:8470 "
    "url.")

flags.DEFINE_string(
    "tpu_zone", None,
    "[Optional] GCE zone where the Cloud TPU is located in. If not "
    "specified, we will attempt to automatically detect the GCE zone from "
    "metadata.")

flags.DEFINE_string(
    "gcp_project", None,
    "[Optional] Project name for the Cloud TPU-enabled project. If not "
    "specified, we will attempt to automatically detect the GCE project from "
    "metadata.")

flags.DEFINE_string("master", None, "[Optional] TensorFlow master URL.")

flags.DEFINE_bool(
    "verbose_logging", False,
    "If true, all of the warnings related to data processing will be printed. "
    "A number of warnings are expected for a normal NQ evaluation.")

flags.DEFINE_string(
    "tpu_job_name", None,
    "Name of TPU worker binary. Only necessary if job name is changed from"
    " default tpu_worker.")

flags.DEFINE_bool(
    "mask_long_output", False,
    "Whether to mask out invalid positions in the long input as possible "
    "outputs (padding, separator, and question word pieces tokens, which "
    "can never be the start or end of either SA or LA).")

flags.DEFINE_integer(
    "sep_tok_id", 102,
    "Token ID of the [SEP] token, this is only used for determining the mask "
    "used if 'mask_long_output' is True.")

flags.DEFINE_integer(
    "cls_tok_id", 101,
    "Token ID of the [CLS] token, this is only used for determining the mask "
    "used if 'mask_long_output' is True.")

flags.DEFINE_enum(
    "span_selection_method", "joint",
    ["disjoint", "joint", "joint-exhaustive"],
    "The span selection method to use for the inference logic (also affecting "
    "evaluation).")

flags.DEFINE_bool(
    "consider_answer_type", True,
    "Whether to consider the answer type in the inference logic (also "
    "affecting evaluation).")

flags.DEFINE_integer(
    "max_short_answer_len", 32,
    "Maximum short answer length to consider in the inference logic (also "
    "affecting evaluation).")


def _get_global_step_for_checkpoint(checkpoint_path):
  """Returns the global step for the checkpoint path, or -1 if not found."""
  re_match = re.search(r"ckpt-(\d+)$", checkpoint_path)
  return -1 if re_match is None else int(re_match.group(1))


def _make_scalar_summary(tag, value):
  """Builds a TF Summary proto holding a single scalar.

  Args:
    tag: Name under which the scalar is recorded.
    value: Float value to record.

  Returns:
    A `tf.summary.Summary` proto with one `simple_value` entry.
  """
  scalar_entry = tf.summary.Summary.Value(tag=tag, simple_value=value)
  return tf.summary.Summary(value=[scalar_entry])


def _write_predictions(estimator, input_fn, output_file, checkpoint_path):
  """Runs prediction and writes every result to a GZIP-compressed TFRecord file.

  Args:
    estimator: The `TPUEstimator` used for prediction.
    input_fn: Input function producing the examples to predict on.
    output_file: Path of the GZIP-compressed TFRecord file to write.
    checkpoint_path: Checkpoint whose weights are restored. Passed explicitly
      so the predictions belong to this checkpoint rather than to whatever
      checkpoint is latest in the model directory when prediction starts.

  Returns:
    The number of predictions written.
  """
  gz = tf.python_io.TFRecordOptions(tf.python_io.TFRecordCompressionType.GZIP)
  processed_examples = 0
  with tf.python_io.TFRecordWriter(output_file, options=gz) as writer:
    for prediction in estimator.predict(
        input_fn, checkpoint_path=checkpoint_path, yield_single_examples=True):
      if processed_examples % 1000 == 0:
        tf.logging.info("Processing example: %d", processed_examples)
      run_nq_lib.process_prediction(prediction, writer)
      processed_examples += 1
  return processed_examples


def _evaluate_checkpoint(evaluator, output_file, checkpoint_path, global_step,
                         summary_writer):
  """Scores a prediction file and records the resulting F1 metrics.

  The long, short and average F1 are written both as TensorBoard scalars (at
  `global_step`) and as JSON to `<checkpoint_path>.f1.json`.

  Args:
    evaluator: An `eval_nq_lib.NQEvaluator`.
    output_file: Prediction TFRecord file written by `_write_predictions`.
    checkpoint_path: Checkpoint the predictions were produced with.
    global_step: Step used to index the TensorBoard summaries.
    summary_writer: A `tf.summary.FileWriter` for the eval directory.
  """
  scores = evaluator.evaluate_predicted(
      output_file,
      span_selection_method=FLAGS.span_selection_method,
      consider_answer_type=FLAGS.consider_answer_type)
  long_f1 = scores["long-best-threshold-f1"]
  short_f1 = scores["short-best-threshold-f1"]
  average_f1 = (long_f1 + short_f1) / 2

  for tag, value in (("eval_metrics/long_f1", long_f1),
                     ("eval_metrics/short_f1", short_f1),
                     ("eval_metrics/average_f1", average_f1)):
    summary_writer.add_summary(
        _make_scalar_summary(tag=tag, value=value), global_step=global_step)
  summary_writer.flush()

  with tf.gfile.GFile(checkpoint_path + ".f1.json", "w") as writer_json:
    writer_json.write(
        json.dumps(
            dict(long_f1=long_f1, short_f1=short_f1, average_f1=average_f1)))


def main(_):
  """Trains, evaluates and/or runs inference with an ETC model on NQ.

  Each stage is enabled by its own flag and they run in this order:
    * `do_train`: trains on precomputed features, then writes a
      `training_done` marker file into `output_dir`.
    * `do_predict`: waits for new checkpoints in `output_dir`, predicts on
      `predict_precomputed_file` with each one, scores the predictions against
      `gold_cache_path`, and logs the F1 scores to TensorBoard and to
      `<checkpoint>.f1.json`. Stops after the latest checkpoint once the
      `training_done` marker exists, or when the checkpoint iterator times out.
    * `do_inference`: predicts with `init_checkpoint` and writes the official
      output format to `<init_checkpoint>.<inference_output_name>.json`.

  Raises:
    ValueError: If a flag required by an enabled stage is missing.
  """
  tf.logging.set_verbosity(tf.logging.INFO)
  tf.gfile.MakeDirs(FLAGS.output_dir)

  etc_model_config = modeling.EtcConfig.from_json_file(FLAGS.etc_config_file)
  if FLAGS.grad_checkpointing_period is not None:
    etc_model_config.grad_checkpointing_period = FLAGS.grad_checkpointing_period

  tpu_cluster_resolver = None
  if FLAGS.use_tpu and FLAGS.tpu_name:
    tpu_cluster_resolver = tf.distribute.cluster_resolver.TPUClusterResolver(
        FLAGS.tpu_name, zone=FLAGS.tpu_zone, project=FLAGS.gcp_project)

  is_per_host = tf_estimator.tpu.InputPipelineConfig.PER_HOST_V2
  run_config = tf_estimator.tpu.RunConfig(
      cluster=tpu_cluster_resolver,
      master=FLAGS.master,
      model_dir=FLAGS.output_dir,
      save_checkpoints_steps=FLAGS.save_checkpoints_steps,
      tpu_config=tf_estimator.tpu.TPUConfig(
          iterations_per_loop=FLAGS.iterations_per_loop,
          tpu_job_name=FLAGS.tpu_job_name,
          per_host_input_for_training=is_per_host))

  train_steps = None
  warmup_steps = None
  if FLAGS.do_train:
    # Fail fast with a clear message instead of a TypeError in the step
    # arithmetic below (or a failure deep inside the input pipeline).
    if FLAGS.train_num_precomputed is None:
      raise ValueError("Must specify `train_num_precomputed` for training.")
    if FLAGS.train_precomputed_file is None:
      raise ValueError("Must specify `train_precomputed_file` for training.")
    num_train_features = FLAGS.train_num_precomputed
    train_steps = int(num_train_features / FLAGS.train_batch_size *
                      FLAGS.num_train_epochs)
    warmup_steps = int(train_steps * FLAGS.warmup_proportion)

  model_fn = run_nq_lib.model_fn_builder(etc_model_config, train_steps,
                                         warmup_steps, FLAGS)

  estimator = tf_estimator.tpu.TPUEstimator(
      use_tpu=FLAGS.use_tpu,
      model_fn=model_fn,
      config=run_config,
      train_batch_size=FLAGS.train_batch_size,
      predict_batch_size=FLAGS.predict_batch_size)

  # Marker file the eval loop uses to know that no more checkpoints will come.
  training_done_path = os.path.join(FLAGS.output_dir, "training_done")

  if FLAGS.do_train:
    tf.logging.info("***** Running training on precomputed features *****")
    tf.logging.info("  Num split examples = %d", num_train_features)
    tf.logging.info("  Batch size = %d", FLAGS.train_batch_size)
    tf.logging.info("  Num steps = %d", train_steps)
    train_input_fn = run_nq_lib.input_fn_builder(
        input_file=FLAGS.train_precomputed_file,
        flags=FLAGS,
        etc_model_config=etc_model_config,
        is_training=True,
        drop_remainder=True)
    estimator.train(input_fn=train_input_fn, max_steps=train_steps)

    # Write file to signal training is done.
    with tf.gfile.GFile(training_done_path, "w") as writer:
      writer.write("\n")

  if FLAGS.do_predict:
    evaluator = eval_nq_lib.NQEvaluator(
        FLAGS.gold_cache_path,
        max_short_answer_len=FLAGS.max_short_answer_len)

    tf.logging.info("***** Running predictions on precomputed features *****")
    tf.logging.info("  Batch size = %d", FLAGS.predict_batch_size)

    predict_input_fn = run_nq_lib.input_fn_builder(
        input_file=FLAGS.predict_precomputed_file,
        flags=FLAGS,
        etc_model_config=etc_model_config,
        is_training=False,
        drop_remainder=False)

    # Writer for TensorBoard; closed in `finally` so event files are flushed
    # even if prediction raises.
    summary_writer = tf.summary.FileWriter(
        os.path.join(FLAGS.output_dir, "eval"))
    try:
      for checkpoint_path in tf.train.checkpoints_iterator(
          FLAGS.output_dir, min_interval_secs=15 * 60, timeout=16 * 60 * 60):
        global_step = _get_global_step_for_checkpoint(checkpoint_path)
        output_file = checkpoint_path + ".predicted-tfrecords"
        _write_predictions(estimator, predict_input_fn, output_file,
                           checkpoint_path)

        # Calculate F1 after all examples have been written. Best effort: a
        # failed evaluation must not stop the loop over later checkpoints.
        try:
          _evaluate_checkpoint(evaluator, output_file, checkpoint_path,
                               global_step, summary_writer)
        except Exception as e:  # pylint: disable=broad-except
          tf.logging.error("Failed to evaluate checkpoint: {}, {}".format(
              checkpoint_path, repr(e)))

        # Break if training has finished and the checkpoint we just processed
        # is the last one.
        if tf.io.gfile.exists(training_done_path):
          last_checkpoint = tf.train.latest_checkpoint(FLAGS.output_dir)
          if (last_checkpoint is not None and
              _get_global_step_for_checkpoint(last_checkpoint) == global_step):
            break
    finally:
      summary_writer.close()

  if FLAGS.do_inference:
    tf.logging.info("***** Running inference *****")
    tf.logging.info("  Batch size = %d", FLAGS.predict_batch_size)

    if FLAGS.init_checkpoint is None:
      raise ValueError("Must specify `init_checkpoint` for inference.")

    predict_input_fn = run_nq_lib.input_fn_builder(
        input_file=FLAGS.predict_precomputed_file,
        flags=FLAGS,
        etc_model_config=etc_model_config,
        is_training=False,
        drop_remainder=False)

    output_file = (
        f"{FLAGS.init_checkpoint}.{FLAGS.inference_output_name}.tfrecords.gz")
    _write_predictions(estimator, predict_input_fn, output_file,
                       FLAGS.init_checkpoint)

    inference_logic = eval_nq_lib.NQInference(
        max_short_answer_len=FLAGS.max_short_answer_len)
    pred_by_document = inference_logic.compute_predictions(output_file)
    official_output = inference_logic.compute_official(
        pred_by_document,
        span_selection_method=FLAGS.span_selection_method,
        consider_answer_type=FLAGS.consider_answer_type)
    with tf.io.gfile.GFile(
        f"{FLAGS.init_checkpoint}.{FLAGS.inference_output_name}.json",
        "w") as writer:
      json.dump(official_output, writer)


if __name__ == "__main__":
  # Both flags must be supplied on the command line before `main` runs.
  for required_flag_name in ("etc_config_file", "output_dir"):
    flags.mark_flag_as_required(required_flag_name)
  tf.app.run()
